load("@bazel_skylib//lib:paths.bzl", "paths")
load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("@rules_cc//cc:defs.bzl", "cc_binary", "cc_library", "cc_test")
load("@rules_python//python:defs.bzl", "py_library", "py_test")
load("@pytorch//third_party:substitution.bzl", "header_template_rule", "template_rule")
load("@pytorch//:tools/bazel.bzl", "rules")
load("@pytorch//tools/rules:cu.bzl", "cu_library")
load("@pytorch//tools/config:defs.bzl", "if_cuda")
load("@pytorch//:aten.bzl", "generate_aten", "intern_build_aten_ops")
load(":build.bzl", "GENERATED_AUTOGRAD_CPP", "GENERATED_AUTOGRAD_PYTHON", "define_targets")
load(":build_variables.bzl", "jit_core_sources", "lazy_tensor_ts_sources", "libtorch_core_sources", "libtorch_cuda_sources", "libtorch_distributed_sources", "libtorch_extra_sources", "libtorch_python_core_sources", "torch_cpp_srcs", "libtorch_python_cuda_sources", "libtorch_python_distributed_sources")
load(":ufunc_defs.bzl", "aten_ufunc_generated_cpu_kernel_sources", "aten_ufunc_generated_cpu_sources", "aten_ufunc_generated_cuda_sources")
load("//:tools/bazel.bzl", "rules")

# Export files for use by torch/headeronly (where version.h generation now lives)
exports_files(["version.txt"])

# Instantiate the targets declared by define_targets() (loaded from build.bzl),
# handing it the `rules` object loaded from tools/bazel.bzl.
define_targets(rules = rules)

# Compiler flags shared by the flag lists defined further down (ATEN_COPTS,
# CAFFE2_COPTS, TORCH_COPTS) and used directly by the torch_python target.
# The second list is routed through if_cuda(), so it is presumably only
# applied when the build is configured for CUDA.
COMMON_COPTS = [
    "-DHAVE_MALLOC_USABLE_SIZE=1",
    "-DHAVE_MMAP=1",
    "-DHAVE_SHM_OPEN=1",
    "-DHAVE_SHM_UNLINK=1",
    "-D_FILE_OFFSET_BITS=64",
    "-DUSE_FBGEMM",
    "-DUSE_DISTRIBUTED",
    "-DAT_PER_OPERATOR_HEADERS",
    "-DATEN_THREADING=NATIVE",
    "-DNO_CUDNN_DESTROY_HANDLE",
] + if_cuda([
    "-DUSE_CUDA",
    "-DUSE_CUDNN",
    # Use cub in a safe manner, see:
    # https://github.com/pytorch/pytorch/pull/55292
    # TODO: This should be passed only when building for CUDA-11.5 or newer.
    "-DCUB_WRAPPED_NAMESPACE=at_cuda_detail",
])

# Inputs handed to the ATen code generator (see generated_aten_cpp): the two
# YAML schema files followed by everything under the templates directory.
aten_generation_srcs = [
    "aten/src/ATen/native/native_functions.yaml",
    "aten/src/ATen/native/tags.yaml",
] + glob(["aten/src/ATen/templates/**"])

# Generated ATen sources and headers that are not CUDA specific.
# This list is declared as outputs of the `generated_aten_cpp` rule and is
# also appended to the srcs of the `aten` library, so every entry must match a
# file the generator really produces.
generated_cpu_cpp = [
    "aten/src/ATen/RegisterBackendSelect.cpp",
    "aten/src/ATen/RegisterCPU_0.cpp",
    "aten/src/ATen/RegisterCPU_1.cpp",
    "aten/src/ATen/RegisterCPU_2.cpp",
    "aten/src/ATen/RegisterCPU_3.cpp",
    "aten/src/ATen/RegisterFunctionalization_0.cpp",
    "aten/src/ATen/RegisterFunctionalization_1.cpp",
    "aten/src/ATen/RegisterFunctionalization_2.cpp",
    "aten/src/ATen/RegisterFunctionalization_3.cpp",
    # Intentionally left out of the list (kept for reference):
    # "aten/src/ATen/RegisterFunctionalizationEverything.cpp",
    "aten/src/ATen/RegisterMkldnnCPU_0.cpp",
    "aten/src/ATen/RegisterNestedTensorCPU_0.cpp",
    "aten/src/ATen/RegisterQuantizedCPU_0.cpp",
    "aten/src/ATen/RegisterSparseCPU_0.cpp",
    "aten/src/ATen/RegisterSparseCsrCPU_0.cpp",
    "aten/src/ATen/RegisterZeroTensor_0.cpp",
    "aten/src/ATen/RegisterCompositeImplicitAutograd_0.cpp",
    "aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor_0.cpp",
    "aten/src/ATen/RegisterCompositeExplicitAutograd_0.cpp",
    "aten/src/ATen/RegisterCompositeExplicitAutogradNonFunctional_0.cpp",
    "aten/src/ATen/RegisterMeta_0.cpp",
    "aten/src/ATen/RegisterSparseMeta_0.cpp",
    "aten/src/ATen/RegisterQuantizedMeta_0.cpp",
    "aten/src/ATen/RegisterNestedTensorMeta_0.cpp",
    "aten/src/ATen/RegisterSchema.cpp",
    "aten/src/ATen/CPUFunctions.h",
    "aten/src/ATen/CPUFunctions_inl.h",
    "aten/src/ATen/CompositeExplicitAutogradFunctions.h",
    "aten/src/ATen/CompositeExplicitAutogradFunctions_inl.h",
    "aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h",
    "aten/src/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
    "aten/src/ATen/CompositeImplicitAutogradFunctions.h",
    "aten/src/ATen/CompositeImplicitAutogradFunctions_inl.h",
    "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions.h",
    "aten/src/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h",
    "aten/src/ATen/CompositeViewCopyKernels.cpp",
    "aten/src/ATen/FunctionalInverses.h",
    "aten/src/ATen/Functions.h",
    "aten/src/ATen/Functions.cpp",
    "aten/src/ATen/RedispatchFunctions.h",
    "aten/src/ATen/Operators.h",
    "aten/src/ATen/Operators_0.cpp",
    "aten/src/ATen/Operators_1.cpp",
    "aten/src/ATen/Operators_2.cpp",
    "aten/src/ATen/Operators_3.cpp",
    "aten/src/ATen/Operators_4.cpp",
    "aten/src/ATen/NativeFunctions.h",
    "aten/src/ATen/MetaFunctions.h",
    "aten/src/ATen/MetaFunctions_inl.h",
    "aten/src/ATen/MethodOperators.h",
    "aten/src/ATen/NativeMetaFunctions.h",
    "aten/src/ATen/RegistrationDeclarations.h",
    "aten/src/ATen/VmapGeneratedPlumbing.h",
    "aten/src/ATen/ViewMetaClasses.h",
    "aten/src/ATen/ViewMetaClasses.cpp",
    "aten/src/ATen/core/aten_interned_strings.h",
    "aten/src/ATen/core/enum_tag.h",
    "aten/src/ATen/core/TensorBody.h",
    "aten/src/ATen/core/TensorMethods.cpp",
    "aten/src/ATen/core/ATenOpList.cpp",
]

# Generated ATen sources and headers for the CUDA backends.
# Declared as outputs of `generated_aten_cpp` and compiled by `aten_cuda_cpp`.
generated_cuda_cpp = [
    "aten/src/ATen/CUDAFunctions.h",
    "aten/src/ATen/CUDAFunctions_inl.h",
    "aten/src/ATen/RegisterCUDA_0.cpp",
    "aten/src/ATen/RegisterNestedTensorCUDA_0.cpp",
    "aten/src/ATen/RegisterQuantizedCUDA_0.cpp",
    "aten/src/ATen/RegisterSparseCUDA_0.cpp",
    "aten/src/ATen/RegisterSparseCsrCUDA_0.cpp",
]

# Runs the //torchgen:gen generator over aten_generation_srcs. The declared
# outputs are the CPU and CUDA lists above, the ufunc-generated sources
# (their file names come from the helpers in ufunc_defs.bzl, formatted into
# the "aten/src/ATen/{}" pattern), and Declarations.yaml.
generate_aten(
    name = "generated_aten_cpp",
    srcs = aten_generation_srcs,
    outs = (
        generated_cpu_cpp +
        generated_cuda_cpp +
        aten_ufunc_generated_cpu_sources("aten/src/ATen/{}") +
        aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}") +
        aten_ufunc_generated_cuda_sources("aten/src/ATen/{}") + [
            "aten/src/ATen/Declarations.yaml",
        ]
    ),
    generator = "//torchgen:gen",
)

# Groups the generated autograd C++ files (GENERATED_AUTOGRAD_CPP from
# build.bzl) under one label; `libtorch_headers` depends on this filegroup.
# The `:generate-code` target is not defined in this file (presumably created
# by define_targets()).
filegroup(
    name = "cpp_generated_code",
    srcs = GENERATED_AUTOGRAD_CPP,
    data = [":generate-code"],
)

# ATen
# Hand-written .cpp files located directly in the listed ATen directories
# (the `*` pattern does not descend into subdirectories).
filegroup(
    name = "aten_base_cpp",
    srcs = glob([
        "aten/src/ATen/*.cpp",
        "aten/src/ATen/functorch/*.cpp",
        "aten/src/ATen/detail/*.cpp",
        "aten/src/ATen/cpu/*.cpp",
    ]),
)

# Every .cpp file below aten/src/ATen/core (recursively), minus *_test.cpp.
filegroup(
    name = "ATen_CORE_SRCS",
    srcs = glob(
        ["aten/src/ATen/core/**/*.cpp"],
        exclude = ["aten/src/ATen/core/**/*_test.cpp"],
    ),
)

# Native operator implementations located directly in ATen/native.
filegroup(
    name = "aten_native_cpp",
    srcs = glob(["aten/src/ATen/native/*.cpp"]),
)

# Native operator implementations located directly in ATen/native/sparse.
filegroup(
    name = "aten_native_sparse_cpp",
    srcs = glob(["aten/src/ATen/native/sparse/*.cpp"]),
)

# Native operator implementations located directly in ATen/native/nested.
filegroup(
    name = "aten_native_nested_cpp",
    srcs = glob(["aten/src/ATen/native/nested/*.cpp"]),
)

# Quantized native operators: files directly in native/quantized and in its
# cpu/ subdirectory.
filegroup(
    name = "aten_native_quantized_cpp",
    srcs = glob([
        "aten/src/ATen/native/quantized/*.cpp",
        "aten/src/ATen/native/quantized/cpu/*.cpp",
    ]),
)

# Native operator implementations located directly in ATen/native/transformers.
filegroup(
    name = "aten_native_transformers_cpp",
    srcs = glob(["aten/src/ATen/native/transformers/*.cpp"]),
)

# MKL-related sources: both the native/mkl operators and the ATen/mkl
# directory are collected under this single label.
filegroup(
    name = "aten_native_mkl_cpp",
    srcs = glob([
        "aten/src/ATen/native/mkl/*.cpp",
        "aten/src/ATen/mkl/*.cpp",
    ]),
)

# Sources located directly in ATen/native/mkldnn.
filegroup(
    name = "aten_native_mkldnn_cpp",
    srcs = glob(["aten/src/ATen/native/mkldnn/*.cpp"]),
)

# Sources located directly in ATen/native/xnnpack.
filegroup(
    name = "aten_native_xnnpack",
    srcs = glob(["aten/src/ATen/native/xnnpack/*.cpp"]),
)

# Sources located directly in ATen/vulkan.
filegroup(
    name = "aten_base_vulkan",
    srcs = glob(["aten/src/ATen/vulkan/*.cpp"]),
)

# Sources located directly in ATen/metal.
filegroup(
    name = "aten_base_metal",
    srcs = glob(["aten/src/ATen/metal/*.cpp"]),
)

# Every .cpp file below aten/src/ATen/quantized (recursively), minus
# *_test.cpp.
filegroup(
    name = "ATen_QUANTIZED_SRCS",
    srcs = glob(
        ["aten/src/ATen/quantized/**/*.cpp"],
        exclude = ["aten/src/ATen/quantized/**/*_test.cpp"],
    ),
)

# Host-side (.cpp) sources of the CUDA/cuDNN/MIOpen parts of ATen; compiled by
# the `aten_cuda_cpp` library. Device code (.cu) lives in `aten_cu_srcs`.
filegroup(
    name = "aten_cuda_cpp_srcs",
    srcs = glob([
        "aten/src/ATen/cuda/*.cpp",
        "aten/src/ATen/cuda/detail/*.cpp",
        "aten/src/ATen/cuda/tunable/*.cpp",
        "aten/src/ATen/cudnn/*.cpp",
        "aten/src/ATen/native/cuda/*.cpp",
        "aten/src/ATen/native/cuda/linalg/*.cpp",
        "aten/src/ATen/native/cudnn/*.cpp",
        "aten/src/ATen/native/miopen/*.cpp",
        "aten/src/ATen/native/nested/cuda/*.cpp",
        "aten/src/ATen/native/quantized/cuda/*.cpp",
        "aten/src/ATen/native/quantized/cudnn/*.cpp",
        "aten/src/ATen/native/sparse/cuda/*.cpp",
        "aten/src/ATen/native/transformers/cuda/*.cpp",
    ]),
)

# CUDA device sources (.cu) for ATen: the checked-in files matched by the
# globs plus the ufunc-generated CUDA sources. Consumed by the `aten_cuda`
# cu_library.
filegroup(
    name = "aten_cu_srcs",
    srcs = glob([
        "aten/src/ATen/cuda/*.cu",
        "aten/src/ATen/cuda/detail/*.cu",
        "aten/src/ATen/native/cuda/*.cu",
        "aten/src/ATen/native/nested/cuda/*.cu",
        "aten/src/ATen/native/quantized/cuda/*.cu",
        "aten/src/ATen/native/sparse/cuda/*.cu",
        "aten/src/ATen/native/transformers/cuda/*.cu",
    ]) + aten_ufunc_generated_cuda_sources("aten/src/ATen/{}"),
    # It's a bit puzzling to me why it's not necessary to declare the
    # target that generates these sources...
)

# Produces aten/src/ATen/Config.h from Config.h.in with the @...@ placeholders
# replaced by the fixed values below (this Bazel build does not probe the
# host; every feature switch is hard-coded here).
# TODO: Enable support for KleidiAI bazel build
header_template_rule(
    name = "aten_src_ATen_config",
    src = "aten/src/ATen/Config.h.in",
    out = "aten/src/ATen/Config.h",
    include = "aten/src",
    substitutions = {
        "@AT_MKLDNN_ENABLED@": "1",
        "@AT_MKLDNN_ACL_ENABLED@": "0",
        "@AT_MKL_ENABLED@": "1",
        "@AT_MKL_SEQUENTIAL@": "0",
        "@AT_POCKETFFT_ENABLED@": "0",
        "@AT_NNPACK_ENABLED@": "0",
        "@CAFFE2_STATIC_LINK_CUDA_INT@": "0",
        "@AT_BUILD_WITH_BLAS@": "1",
        "@AT_BUILD_WITH_LAPACK@": "1",
        "@AT_PARALLEL_OPENMP@": "0",
        "@AT_PARALLEL_NATIVE@": "1",
        "@AT_BLAS_F2C@": "0",
        "@AT_BLAS_USE_CBLAS_DOT@": "1",
        "@AT_KLEIDIAI_ENABLED@": "0",
        "@AT_USE_EIGEN_SPARSE@": "0",
    },
)

# Produces aten/src/ATen/cuda/CUDAConfig.h from CUDAConfig.h.in with the
# placeholders replaced by the fixed values below. Exposed as a header of the
# `aten_cuda_cpp` library.
header_template_rule(
    name = "aten_src_ATen_cuda_config",
    src = "aten/src/ATen/cuda/CUDAConfig.h.in",
    out = "aten/src/ATen/cuda/CUDAConfig.h",
    include = "aten/src",
    substitutions = {
        "@AT_CUDNN_ENABLED@": "1",
        "@AT_CUSPARSELT_ENABLED@": "0",
        "@AT_HIPSPARSELT_ENABLED@": "0",
        "@AT_ROCM_ENABLED@": "0",
        "@AT_MAGMA_ENABLED@": "0",
        "@NVCC_FLAGS_EXTRA@": "",
    },
)

# Header-only view of ATen: two torch headers that ATen code includes, every
# header under aten/src, the generated Config.h and the outputs of the ATen
# code generator. `aten/src` is added to the include path of dependents.
cc_library(
    name = "aten_headers",
    hdrs = [
        "torch/csrc/Export.h",
        "torch/csrc/jit/frontend/function_schema_parser.h",
    ] + glob([
        "aten/src/**/*.h",
        "aten/src/**/*.hpp",
        "aten/src/ATen/cuda/**/*.cuh",
        "aten/src/ATen/native/**/*.cuh",
        "aten/src/THC/*.cuh",
    ]) + [
        ":aten_src_ATen_config",
        ":generated_aten_cpp",
    ],
    includes = ["aten/src"],
    deps = ["//c10"],
)

# Compiler flags for the ATen libraries (aten, aten_nvrtc, aten_cuda_cpp,
# aten_cuda and the targets created by intern_build_aten_ops).
# NOTE(review): this list defines CAFFE2_BUILD_MAIN_LIBS (plural) whereas
# CAFFE2_COPTS below defines CAFFE2_BUILD_MAIN_LIB; confirm which spelling the
# C++ sources actually test before changing either.
ATEN_COPTS = COMMON_COPTS + [
    "-DCAFFE2_BUILD_MAIN_LIBS",
    "-DHAVE_AVX_CPU_DEFINITION",
    "-DHAVE_AVX2_CPU_DEFINITION",
    "-fvisibility-inlines-hidden",
    "-fno-math-errno",
    "-fno-trapping-math",
]

# Macro from aten.bzl. It is given the ATen flags, the ufunc-generated CPU
# kernel sources as extra implementations, and the libraries those kernels
# need. The targets it creates are not visible here; the `aten` library below
# depends on `:ATen_CPU`, which is presumably one of them.
intern_build_aten_ops(
    copts = ATEN_COPTS,
    extra_impls = aten_ufunc_generated_cpu_kernel_sources("aten/src/ATen/{}"),
    deps = [
        ":aten_headers",
        "@fbgemm",
        "@mkl",
        "@sleef",
        "@mkl_dnn//:mkl-dnn",
    ],
)

# The main ATen library: all hand-written CPU source groups declared above
# plus the generated CPU sources. `alwayslink` keeps every object file in
# dependents even when no symbol is referenced directly.
cc_library(
    name = "aten",
    srcs = [
        ":ATen_CORE_SRCS",
        ":ATen_QUANTIZED_SRCS",
        ":aten_base_cpp",
        ":aten_base_metal",
        ":aten_base_vulkan",
        ":aten_native_cpp",
        ":aten_native_mkl_cpp",
        ":aten_native_mkldnn_cpp",
        ":aten_native_nested_cpp",
        ":aten_native_quantized_cpp",
        ":aten_native_sparse_cpp",
        ":aten_native_transformers_cpp",
        ":aten_native_xnnpack",
        ":aten_src_ATen_config",
    ] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"),
    copts = ATEN_COPTS,
    # The NVRTC stub shared object is shipped as a runtime data file in the
    # first (CUDA) branch of if_cuda; the other branch adds nothing.
    data = if_cuda(
        [":libcaffe2_nvrtc.so"],
        [],
    ),
    linkopts = ["-ldl"],
    visibility = ["//visibility:public"],
    deps = [
        ":ATen_CPU",
        ":aten_headers",
        ":caffe2_for_aten_headers",
        ":torch_headers",
        "@fbgemm",
        "@ideep",
        "@nlohmann",
    ],
    alwayslink = True,
)

# Static library built from the sources in aten/src/ATen/cuda/nvrtc_stub.
# It links against the CUDA driver and NVRTC targets and is wrapped into
# libcaffe2_nvrtc.so below.
cc_library(
    name = "aten_nvrtc",
    srcs = glob([
        "aten/src/ATen/cuda/nvrtc_stub/*.cpp",
    ]),
    copts = ATEN_COPTS,
    linkstatic = True,
    visibility = ["//visibility:public"],
    deps = [
        ":aten_headers",
        "//c10",
        "@cuda",
        "@cuda//:cuda_driver",
        "@cuda//:nvrtc",
    ],
    alwayslink = True,
)

# Shared-object packaging of `aten_nvrtc` (linkshared = True, no sources of
# its own). The `aten` library lists it as runtime data for CUDA builds.
cc_binary(
    name = "libcaffe2_nvrtc.so",
    linkshared = True,
    visibility = ["//visibility:public"],
    deps = [
        ":aten_nvrtc",
    ],
)

# Host-side C++ part of ATen's CUDA support: the globbed CUDA .cpp sources
# plus the generated CUDA registration files, with the generated
# CUDAConfig.h exported as a header.
cc_library(
    name = "aten_cuda_cpp",
    srcs = [":aten_cuda_cpp_srcs"] + generated_cuda_cpp,
    hdrs = [":aten_src_ATen_cuda_config"],
    copts = ATEN_COPTS,
    visibility = ["//visibility:public"],
    deps = [
        ":aten",
        "@cuda",
        "@cuda//:cusolver",
        "@cuda//:nvrtc",
        "@cudnn",
        "@cudnn_frontend",
    ],
    alwayslink = True,
)

# Preprocessor defines passed to the cu_library targets in this file
# (`aten_cuda` and `torch_cuda`). Judging by their names, the
# __CUDA_NO_*__ macros turn off half/bfloat16 operators and conversions
# provided by the CUDA headers (not verifiable from this file).
torch_cuda_half_options = [
    "-DCUDA_HAS_FP16=1",
    "-D__CUDA_NO_HALF_OPERATORS__",
    "-D__CUDA_NO_HALF_CONVERSIONS__",
    "-D__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "-D__CUDA_NO_HALF2_OPERATORS__",
]

# CUDA device code of ATen (the .cu files in `aten_cu_srcs`), built with the
# ATen flags plus the half-precision defines, on top of `aten_cuda_cpp`.
cu_library(
    name = "aten_cuda",
    srcs = [":aten_cu_srcs"],
    copts = ATEN_COPTS + torch_cuda_half_options,
    visibility = ["//visibility:public"],
    deps = [
        ":aten_cuda_cpp",
        "//c10/util:bit_cast",
        "@cuda//:cublas",
        "@cuda//:cufft",
        "@cuda//:cusparse",
        "@cutlass",
    ],
    alwayslink = True,
)

# caffe2
# Compiler flags for the caffe2 header and source libraries below.
CAFFE2_COPTS = COMMON_COPTS + [
    "-Dcaffe2_EXPORTS",
    "-DCAFFE2_USE_CUDNN",
    "-DCAFFE2_BUILD_MAIN_LIB",
    "-fvisibility-inlines-hidden",
    "-fno-math-errno",
    "-fno-trapping-math",
]

# caffe2/core sources compiled into the `caffe2` library.
filegroup(
    name = "caffe2_core_srcs",
    srcs = [
        "caffe2/core/common.cc",
    ],
)

# Perfkernel sources compiled into the `caffe2` library itself; the
# *_avx.cc / *_avx2.cc variants are built separately by the
# caffe2_perfkernels_avx* libraries.
filegroup(
    name = "caffe2_perfkernels_srcs",
    srcs = [
        "caffe2/perfkernels/embedding_lookup_idx.cc",
    ],
)


# caffe2/serialize sources compiled into the `caffe2` library.
filegroup(
    name = "caffe2_serialize_srcs",
    srcs = [
        "caffe2/serialize/file_adapter.cc",
        "caffe2/serialize/inline_container.cc",
        "caffe2/serialize/istream_adapter.cc",
        "caffe2/serialize/read_adapter_interface.cc",
    ],
)

# caffe2/utils sources (including the threadpool implementation) compiled
# into the `caffe2` library.
filegroup(
    name = "caffe2_utils_srcs",
    srcs = [
        "caffe2/utils/proto_wrap.cc",
        "caffe2/utils/string_utils.cc",
        "caffe2/utils/threadpool/ThreadPool.cc",
        "caffe2/utils/threadpool/pthreadpool.cc",
        "caffe2/utils/threadpool/pthreadpool_impl.cc",
        "caffe2/utils/threadpool/thread_pool_guard.cpp",
    ],
)

# To achieve finer granularity and make debugging easier, caffe2 is split into
# three libraries: ATen, caffe2 and caffe2_for_aten_headers. The ATen library
# groups the source code under the `aten/` directory, and caffe2 contains most
# files under the `caffe2/` directory. Since the ATen library and the caffe2
# library would otherwise depend on each other, `caffe2_for_aten_headers` is
# split out of `caffe2` to avoid a dependency cycle.
cc_library(
    name = "caffe2_for_aten_headers",
    hdrs = [
        "caffe2/core/common.h",
        "caffe2/perfkernels/common.h",
        "caffe2/perfkernels/embedding_lookup_idx.h",
        "caffe2/utils/fixed_divisor.h",
    ] + glob([
        "caffe2/utils/threadpool/*.h",
    ]),
    copts = CAFFE2_COPTS,
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2_core_macros",
        "//c10",
    ],
)

# Remaining caffe2 headers (perfkernels, serialize, utils, threadpool) plus
# all headers under modules/. The *.cuh headers are added through if_cuda().
# NOTE(review): none of the include patterns matches caffe2/core/, so the
# `exclude` entry for caffe2/core/macros.h currently filters nothing; it is
# kept as-is.
cc_library(
    name = "caffe2_headers",
    hdrs = glob(
        [
            "caffe2/perfkernels/*.h",
            "caffe2/serialize/*.h",
            "caffe2/utils/*.h",
            "caffe2/utils/threadpool/*.h",
            "modules/**/*.h",
        ],
        exclude = [
            "caffe2/core/macros.h",
        ],
    ) + if_cuda(glob([
        "caffe2/**/*.cuh",
    ])),
    copts = CAFFE2_COPTS,
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2_core_macros",
        ":caffe2_for_aten_headers",
    ],
)

# The caffe2 library: core, perfkernels, serialize and utils sources built
# with the caffe2 flags plus -mf16c. Depending on the if_cuda() branch it
# pulls in either the CUDA flavour (aten_cuda + tensorpipe_cuda) or the CPU
# flavour (aten + tensorpipe_cpu) of its ATen/tensorpipe dependencies.
cc_library(
    name = "caffe2",
    srcs = [
        ":caffe2_core_srcs",
        ":caffe2_perfkernels_srcs",
        ":caffe2_serialize_srcs",
        ":caffe2_utils_srcs",
    ],
    copts = CAFFE2_COPTS + ["-mf16c"],
    # Spelled as a boolean literal, like every other linkstatic/alwayslink
    # attribute in this file (previously the integer 1).
    linkstatic = True,
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2_core_macros",
        ":caffe2_headers",
        ":caffe2_perfkernels_avx",
        ":caffe2_perfkernels_avx2",
        "//third_party/miniz-3.0.2:miniz",
        "@com_google_protobuf//:protobuf",
        "@eigen",
        "@fbgemm//:fbgemm_src_headers",
        "@fmt",
        "@onnx",
    ] + if_cuda(
        [
            ":aten_cuda",
            "@tensorpipe//:tensorpipe_cuda",
        ],
        [
            ":aten",
            "@tensorpipe//:tensorpipe_cpu",
        ],
    ),
    alwayslink = True,
)

# CUDA device sources of the distributed (c10d) code. The same three .cu
# files are excluded from the srcs of the `torch` cc_library, which depends
# on this target for CUDA builds instead.
cu_library(
    name = "torch_cuda",
    srcs = [
        "torch/csrc/distributed/c10d/NanCheck.cu",
        "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
        "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu",
    ],
    copts = torch_cuda_half_options,
    visibility = ["//visibility:public"],
    deps = [
        ":aten",
        "@cuda//:cublas",
        "@cuda//:curand",
        "@cudnn",
        "@eigen",
        "@tensorpipe//:tensorpipe_cuda",
    ],
    alwayslink = True,
)

# Compiler flags for the caffe2_perfkernels_avx* libraries. This list is
# self-contained (it does not extend COMMON_COPTS); each library appends its
# own -mavx/-mavx2/-mfma flags.
# NOTE(review): "-DTH_BALS_MKL" looks like a transposition of TH_BLAS_MKL.
# It is left untouched because defining the corrected macro could enable code
# paths that are not compiled today; confirm before renaming or removing it.
PERF_COPTS = [
    "-DHAVE_AVX_CPU_DEFINITION",
    "-DHAVE_AVX2_CPU_DEFINITION",
    "-DENABLE_ALIAS=1",
    "-DHAVE_MALLOC_USABLE_SIZE=1",
    "-DHAVE_MMAP=1",
    "-DHAVE_SHM_OPEN=1",
    "-DHAVE_SHM_UNLINK=1",
    "-DSLEEF_STATIC_LIBS=1",
    "-DTH_BALS_MKL",
    "-D_FILE_OFFSET_BITS=64",
    "-DUSE_FBGEMM",
    "-fvisibility-inlines-hidden",
    "-Wunused-parameter",
    "-fno-math-errno",
    "-fno-trapping-math",
    "-mf16c",
]

# Headers exported by both caffe2_perfkernels_avx* libraries: everything
# directly in caffe2/perfkernels and caffe2/core.
PERF_HEADERS = glob([
    "caffe2/perfkernels/*.h",
    "caffe2/core/*.h",
])

# Perfkernel translation units whose names end in _avx.cc, compiled with
# -mavx on top of PERF_COPTS.
cc_library(
    name = "caffe2_perfkernels_avx",
    srcs = glob([
        "caffe2/perfkernels/*_avx.cc",
    ]),
    hdrs = PERF_HEADERS,
    copts = PERF_COPTS + [
        "-mavx",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2_headers",
        "//c10",
    ],
    alwayslink = True,
)

# Perfkernel translation units whose names end in _avx2.cc, compiled with
# -mavx2, -mfma and -mavx on top of PERF_COPTS.
cc_library(
    name = "caffe2_perfkernels_avx2",
    srcs = glob([
        "caffe2/perfkernels/*_avx2.cc",
    ]),
    hdrs = PERF_HEADERS,
    copts = PERF_COPTS + [
        "-mavx2",
        "-mfma",
        "-mavx",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2_headers",
        "//c10",
    ],
    alwayslink = True,
)

# torch
# Headers directly in torch/csrc/cuda. `torch_headers` excludes them from its
# general glob and adds them back only through if_cuda().
torch_cuda_headers = glob(["torch/csrc/cuda/*.h"])

# Runs flatc on the mobile bytecode schema to produce C++ bindings; the
# generated files are placed under torch/csrc/jit/serialization/ (out_prefix).
flatbuffer_cc_library(
    name = "torch_flatbuffers",
    srcs = [
        "torch/csrc/jit/serialization/mobile_bytecode.fbs",
    ],
    flatc_args = ["--cpp", "--gen-mutable", "--scoped-enums"],
    out_prefix = "torch/csrc/jit/serialization/",
)

# Header library for torch. hdrs is assembled from:
#   * the CUDA headers, only via if_cuda();
#   * the globbed torch headers, excluding anything in a `generated`
#     directory one level below torch/csrc, the checked-in
#     mobile_bytecode_generated.h (torch_flatbuffers provides the generated
#     one), and the CUDA headers handled above;
#   * the generated autograd files and the version header target.
cc_library(
    name = "torch_headers",
    hdrs = if_cuda(
        torch_cuda_headers,
    ) + glob(
        [
            "torch/*.h",
            "torch/csrc/**/*.h",
            "torch/nativert/**/*.h",
            "torch/csrc/distributed/c10d/**/*.hpp",
            "torch/lib/libshm/*.h",
        ],
        exclude = [
            "torch/csrc/*/generated/*.h",
            "torch/csrc/jit/serialization/mobile_bytecode_generated.h",
        ] + torch_cuda_headers,
    ) + GENERATED_AUTOGRAD_CPP + [
        "//torch/headeronly:version_h",
    ],
    includes = [
        "third_party/kineto/libkineto/include",
        "torch/csrc",
        "torch/csrc/api/include",
        "torch/csrc/distributed",
        "torch/lib",
        "torch/lib/libshm",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":aten_headers",
        ":caffe2_headers",
        ":torch_flatbuffers",
        "//c10",
        "@com_github_google_flatbuffers//:flatbuffers",
        "@local_config_python//:python_headers",
        "@onnx",
    ],
    alwayslink = True,
)

# Compiler flags for the `torch` library: the common flags plus torch-specific
# defines and the same floating-point/visibility options used for ATen and
# caffe2.
TORCH_COPTS = COMMON_COPTS + [
    "-Dtorch_EXPORTS",
    "-DHAVE_AVX_CPU_DEFINITION",
    "-DHAVE_AVX2_CPU_DEFINITION",
    "-DCAFFE2_USE_GLOO",
    "-fvisibility-inlines-hidden",
    # No trailing whitespace inside the literal: the flag should read the
    # same as in ATEN_COPTS / CAFFE2_COPTS / PERF_COPTS.
    "-fno-math-errno",
    "-fno-trapping-math",
    "-Wno-error=unused-function",
]

# De-duplicated, order-preserving list of every source compiled into `torch`.
# The individual source lists may overlap; using the paths as dict keys drops
# the repeats, and .keys() yields them in first-seen order.
torch_sources = {
    source: ""
    for source_list in [
        libtorch_core_sources,
        libtorch_distributed_sources,
        torch_cpp_srcs,
        libtorch_extra_sources,
        jit_core_sources,
        lazy_tensor_ts_sources,
        GENERATED_AUTOGRAD_CPP,
    ]
    for source in source_list
}.keys()

# The main torch library. CUDA builds add the files matched by
# libtorch_cuda_sources (minus the NCCL bindings, the symmetric-memory files
# and the .cu files listed in `exclude`, three of which are compiled by
# `torch_cuda`) in front of the common torch_sources.
cc_library(
    name = "torch",
    srcs = if_cuda(glob(
        libtorch_cuda_sources,
        exclude = [
            "torch/csrc/cuda/nccl.cpp",
            "torch/csrc/cuda/python_nccl.cpp",
            "torch/csrc/distributed/c10d/NanCheck.cu",
            "torch/csrc/distributed/c10d/cuda/AsyncMM.cu",
            "torch/csrc/distributed/c10d/quantization/quantization_gpu.cu",
            "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu",
            "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryOps.cu",
            "torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemoryUtils.cpp",
            "torch/csrc/distributed/c10d/symm_mem/cuda_mem_pool.cpp",
            "torch/csrc/distributed/c10d/symm_mem/intra_node_comm.cu",
        ],
    )) + torch_sources,
    copts = TORCH_COPTS,
    defines = ["CAFFE2_NIGHTLY_VERSION=20200115"],
    linkopts = ["-lrt"],
    visibility = ["//visibility:public"],
    deps = [
        ":caffe2",
        ":torch_headers",
        "@kineto",
        "@cpp-httplib",
        "@nlohmann",
    ] + if_cuda([
        "@cuda//:nvToolsExt",
        "@cutlass",
        ":torch_cuda",
    ]),
    alwayslink = True,
)

# Shared-memory helper library built from every .cpp file directly in
# torch/lib/libshm; linked against librt.
cc_library(
    name = "shm",
    srcs = glob(["torch/lib/libshm/*.cpp"]),
    linkopts = ["-lrt"],
    deps = [":torch"],
)

# Catch-all header target: every .h/.cuh file in this package plus the
# generated autograd code, with the torch include directories exported.
cc_library(
    name = "libtorch_headers",
    hdrs = glob([
        "**/*.h",
        "**/*.cuh",
    ]) + [
        # We need the filegroup here because the raw list causes Bazel
        # to see duplicate files. It knows how to deduplicate with the
        # filegroup.
        ":cpp_generated_code",
    ],
    includes = [
        "torch/csrc/api/include",
        "torch/csrc/distributed",
        "torch/lib",
        "torch/lib/libshm",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":torch_headers",
    ],
)

# Python binding layer. The core binding sources and the generated autograd
# Python sources are always compiled; the CUDA and distributed binding sources
# are added through if_cuda(). The torch/csrc/generic/*.cpp files are listed
# as hdrs rather than srcs, i.e. they are made available to the sources above
# instead of being compiled on their own.
cc_library(
    name = "torch_python",
    srcs = (
        libtorch_python_core_sources +
        if_cuda(libtorch_python_cuda_sources) +
        if_cuda(libtorch_python_distributed_sources) +
        GENERATED_AUTOGRAD_PYTHON
    ),
    hdrs = glob(["torch/csrc/generic/*.cpp"]),
    copts = COMMON_COPTS + if_cuda(["-DUSE_CUDA=1"]),
    deps = [
        ":torch",
        ":shm",
        "@pybind11",
    ],
)

# The `torch._C` extension module: a stub C file linked against torch_python
# and the NVRTC stub library. `pytorch_py` consumes the resulting
# `:torch/_C.so` as data.
pybind_extension(
    name = "torch/_C",
    srcs = ["torch/csrc/stub.c"],
    deps = [
        ":torch_python",
        ":aten_nvrtc",
    ],
)

# Shared-memory manager executable built from libshm's manager.cpp on top of
# the `shm` library; shipped as data of `pytorch_py`. linkstatic = False asks
# Bazel to link its dependencies dynamically where possible.
cc_binary(
    name = "torch/bin/torch_shm_manager",
    srcs = [
        "torch/lib/libshm/manager.cpp",
    ],
    deps = [
        ":shm",
    ],
    linkstatic = False,
)

# Produces torch/version.py from torch/version.py.tpl. Both if_cuda() branches
# substitute the same hard-coded {{VERSION}}; they differ only in
# {{CUDA_VERSION}} ("11.2" versus the literal text None).
template_rule(
    name = "gen_version_py",
    src = ":torch/version.py.tpl",
    out = "torch/version.py",
    substitutions = if_cuda({
        # Set default to 11.2. Otherwise Torchvision complains about incompatibility.
        "{{CUDA_VERSION}}": "11.2",
        "{{VERSION}}": "2.0.0",
    }, {
        "{{CUDA_VERSION}}": "None",
        "{{VERSION}}": "2.0.0",
    }),
)

# The Python side of PyTorch: all torch/ and functorch/ .py files, with the
# checked-in torch/version.py swapped for the generated one, plus the native
# extension and the shm manager binary as runtime data.
py_library(
    name = "pytorch_py",
    srcs = glob(
        ["torch/**/*.py"],
        # Any source copy is excluded so that only the generated version.py
        # (added right below) is used.
        exclude = ["torch/version.py"],
    ) + [":torch/version.py"] + glob(["functorch/**/*.py"]),
    data = [
        ":torch/_C.so",
        ":torch/bin/torch_shm_manager",
    ],
    visibility = ["//visibility:public"],
    deps = [
        rules.requirement("numpy"),
        rules.requirement("pyyaml"),
        rules.requirement("requests"),
        rules.requirement("setuptools"),
        rules.requirement("sympy"),
        rules.requirement("typing_extensions"),
        "//torchgen",
    ],
)

# cpp api tests
# Shared helper library for the C++ API tests below (test-only): the support
# source and headers, linked against torch and gtest_main.
cc_library(
    name = "test_support",
    testonly = True,
    srcs = [
        "test/cpp/api/support.cpp",
    ],
    hdrs = [
        "test/cpp/api/init_baseline.h",
        "test/cpp/api/optim_baseline.h",
        "test/cpp/api/support.h",
        "test/cpp/common/support.h",
    ],
    deps = [
        ":torch",
        "@com_google_googletest//:gtest_main",
    ],
)

# Torch integration tests rely on a labeled data set from the MNIST database.
# http://yann.lecun.com/exdb/mnist/

# Source files that each become their own cc_test further down.
# support.cpp is matched by the glob as well. integration.cpp is excluded
# because it has a dedicated `integration_test` target; imethod.cpp is
# excluded without a stated reason.
cpp_api_tests = glob(
    ["test/cpp/api/*.cpp"],
    exclude = [
        "test/cpp/api/imethod.cpp",
        "test/cpp/api/integration.cpp",
    ],
)

# Integration test for the C++ API. It needs the MNIST files produced by the
# `download_mnist` genrule at runtime and is tagged as requiring a GPU.
cc_test(
    name = "integration_test",
    size = "medium",
    srcs = ["test/cpp/api/integration.cpp"],
    data = [
        ":download_mnist",
    ],
    tags = [
        "gpu-required",
    ],
    deps = [
        ":test_support",
        "@com_google_googletest//:gtest_main",
    ],
)

# One cc_test per file in cpp_api_tests. The target name is the file's base
# name without its extension, with "-" replaced by "_" and "_test" appended
# (e.g. "foo-bar.cpp" -> "foo_bar_test").
[
    cc_test(
        name = paths.split_extension(paths.basename(filename))[0].replace("-", "_") + "_test",
        size = "medium",
        srcs = [filename],
        deps = [
            ":test_support",
            "@com_google_googletest//:gtest_main",
        ],
    )
    for filename in cpp_api_tests
]

# Suite of C++ API tests. Apart from `integration_test`, the entries are
# presumably targets created by the per-file comprehension over
# cpp_api_tests, so each name must correspond to an existing
# test/cpp/api/<name without _test>.cpp file.
test_suite(
    name = "api_tests",
    tests = [
        "any_test",
        "autograd_test",
        "dataloader_test",
        "enum_test",
        "expanding_array_test",
        "functional_test",
        "init_test",
        "integration_test",
        "jit_test",
        "memory_test",
        "misc_test",
        "module_test",
        "modulelist_test",
        "modules_test",
        "nn_utils_test",
        "optim_test",
        "ordered_dict_test",
        "rnn_test",
        "sequential_test",
        "serialize_test",
        "static_test",
        "tensor_options_test",
        "tensor_test",
        "torch_include_test",
    ],
)

# dist autograd tests
# Tagged "exclusive" (Bazel runs it without other tests in parallel) and
# "gpu-required".
cc_test(
    name = "torch_dist_autograd_test",
    size = "small",
    srcs = ["test/cpp/dist_autograd/test_dist_autograd.cpp"],
    tags = [
        "exclusive",
        "gpu-required",
    ],
    deps = [
        ":torch",
        "@com_google_googletest//:gtest_main",
    ],
)

# jit tests
# Because these individual unit tests require custom registering,
# it is easier to mimic the cmake build by globbing everything together into
# a single test binary (JIT and tensorexpr tests alike).
cc_test(
    name = "jit_tests",
    size = "small",
    srcs = glob(
        [
            "test/cpp/jit/*.cpp",
            "test/cpp/jit/*.h",
            "test/cpp/tensorexpr/*.cpp",
            "test/cpp/tensorexpr/*.h",
        ],
        exclude = [
            # skip this since <pybind11/embed.h> is not found in OSS build
            "test/cpp/jit/test_exception.cpp",
        ],
    ),
    linkstatic = True,
    tags = [
        "exclusive",
        "gpu-required",
    ],
    deps = [
        ":torch",
        "@com_google_googletest//:gtest_main",
    ],
)

# Lazy-tensor tests, globbed into a single statically linked test binary.
# Note that this target is not listed in the `all_tests` suite of this file.
cc_test(
    name = "lazy_tests",
    size = "small",
    srcs = glob(
        [
            "test/cpp/lazy/*.cpp",
            "test/cpp/lazy/*.h",
        ],
        exclude = [
            # skip these since they depend on generated LazyIr.h which isn't available in bazel yet
            "test/cpp/lazy/test_ir.cpp",
            "test/cpp/lazy/test_lazy_ops.cpp",
            "test/cpp/lazy/test_lazy_ops_util.cpp",
            "test/cpp/lazy/test_lazy_graph_executor.cpp",
        ],
    ),
    linkstatic = True,
    tags = [
        "exclusive",
    ],
    deps = [
        ":torch",
        "@com_google_googletest//:gtest_main",
    ],
)

# python api tests

# Python test run against the Bazel-built `pytorch_py` library; `main` is set
# explicitly because the file name (_test_bazel.py) differs from the target
# name.
py_test(
    name = "test_bazel",
    srcs = ["test/_test_bazel.py"],
    main = "test/_test_bazel.py",
    deps = [
        ":pytorch_py",
        rules.requirement("networkx"),
    ],
)

# all tests
# Aggregates the C++ API suite, the JIT tests, the dist autograd test and the
# c10 tests. `lazy_tests` and `test_bazel` are defined in this file but are
# not part of this suite.
test_suite(
    name = "all_tests",
    tests = [
        "api_tests",
        "jit_tests",
        "torch_dist_autograd_test",
        "//c10/test:tests",
    ],
)

# An internal genrule that we are converging with refers to these files
# as if they were located directly in this package, so we alias them for
# compatibility. Each alias is named after the file's base name and points at
# its real path.

[
    alias(
        name = paths.basename(path),
        actual = path,
    )
    for path in [
        "aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp",
        "aten/src/ATen/templates/DispatchKeyNativeFunctions.h",
        "aten/src/ATen/templates/LazyIr.h",
        "aten/src/ATen/templates/LazyNonNativeIr.h",
        "aten/src/ATen/templates/RegisterDispatchKey.cpp",
        "aten/src/ATen/templates/RegisterDispatchDefinitions.ini",
        "aten/src/ATen/templates/ViewMetaClassesPythonBinding.cpp",
        "aten/src/ATen/native/native_functions.yaml",
        "aten/src/ATen/native/tags.yaml",
        "aten/src/ATen/native/ts_native_functions.yaml",
        "torch/csrc/lazy/core/shape_inference.h",
        "torch/csrc/lazy/ts_backend/ts_native_functions.cpp",
    ]
]

# Fetches the MNIST data set used by `integration_test` by running
# tools/download_mnist.py with the rule's output directory as destination.
# NOTE: the script downloads from the network and relies on a `python3`
# binary being on PATH, so this rule is not hermetic.
genrule(
    name = "download_mnist",
    srcs = ["//:tools/download_mnist.py"],
    outs = [
        "mnist/train-images-idx3-ubyte",
        "mnist/train-labels-idx1-ubyte",
        "mnist/t10k-images-idx3-ubyte",
        "mnist/t10k-labels-idx1-ubyte",
    ],
    # Resolve the script through $(location ...) instead of a hard-coded
    # workspace-relative path, so the command also works when this package is
    # consumed as an external repository (where the file lives under
    # external/<repo>/...).
    cmd = "python3 $(location //:tools/download_mnist.py) -d $(RULEDIR)/mnist",
)
